{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-8ecFD9cgjM9"
   },
   "source": [
    "# RNN practicals\n",
    "\n",
    "This jupyter notebook allows you to reproduce and explore the results presented in the [lecture on RNN](https://dataflowr.github.io/slides/module11.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Pc29CvyHgjNE"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import OrderedDict\n",
    "import scipy.special\n",
    "from scipy.special import binom\n",
    "import matplotlib.pyplot as plt\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "utb4MovzgjNe"
   },
   "outputs": [],
   "source": [
    "def running_mean(x, N):\n",
    "    cumsum = np.cumsum(np.insert(x, 0, 0)) \n",
    "    return (cumsum[N:] - cumsum[:-N]) / float(N)\n",
    "\n",
    "def Catalan(k):\n",
    "    return binom(2*k,k)/(k+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JIGrOVLSgjNx"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "use_gpu = torch.cuda.is_available()\n",
    "def gpu(tensor, gpu=use_gpu):\n",
    "    if gpu:\n",
    "        return tensor.cuda()\n",
    "    else:\n",
    "        return tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1Ytku22GgjOB"
   },
   "source": [
    "## 1. Generation a dataset\n",
    "\n",
    "We have a problem, where we need to generate a dataset made of valid parenthesis strings but also invalid parenthesis string. You can skip to the end of this section to see how parenthesis strings are generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QfpbFNT1gjOF"
   },
   "outputs": [],
   "source": [
    "seq_max_len = 20\n",
    "seq_min_len = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uHAs_VxrgjOS"
   },
   "source": [
    "### generating positive examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DKlEMULfgjOX"
   },
   "outputs": [],
   "source": [
    "# convention: +1 opening parenthesis and -1 closing parenthesis\n",
    "\n",
    "def all_parent(n, a, k=-1):\n",
    "    global res\n",
    "    if k==n-1 and sum(a) == 0:\n",
    "        res.append(a.copy())\n",
    "    elif k==n-1:\n",
    "        pass\n",
    "    else:\n",
    "        k += 1\n",
    "        if sum(a) > 0:\n",
    "            a[k] = 1\n",
    "            all_parent(n,a,k)\n",
    "        \n",
    "            a[k] = -1\n",
    "            all_parent(n,a,k)\n",
    "            a[k] = 0\n",
    "        else:\n",
    "            a[k] = 1\n",
    "            all_parent(n,a,k)\n",
    "            a[k] = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dFcYswvXgjOl"
   },
   "source": [
    "### generating negative examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZXm1-wS9gjOp"
   },
   "outputs": [],
   "source": [
    "def all_parent_mistake(n, a, k=-1):\n",
    "    global res\n",
    "    if k==n-1 and sum(a) >= -1 and sum(a) <= 1 and min(np.cumsum(a))<0:\n",
    "        res.append(a.copy())\n",
    "    elif sum(a) > n-k:\n",
    "        pass\n",
    "    elif k==n-1:\n",
    "        pass\n",
    "    else:\n",
    "        k += 1\n",
    "        if sum(a) >= -1 and k != 0:\n",
    "            a[k] = 1\n",
    "            all_parent_mistake(n,a,k)\n",
    "        \n",
    "            a[k] = -1\n",
    "            all_parent_mistake(n,a,k)\n",
    "            a[k] = 0\n",
    "        else:\n",
    "            a[k] = 1\n",
    "            all_parent_mistake(n,a,k)\n",
    "            a[k] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ePYlXZaYgjO6"
   },
   "outputs": [],
   "source": [
    "# numbering the parentheses\n",
    "# example: seq of len 6\n",
    "# ( ( ( ) ) ) \n",
    "# 0 1 2 4 5 6\n",
    "# we always have ( + ) = seq_len\n",
    "# 'wrong' parentheses are always closing and numbered as:\n",
    "# ) )\n",
    "# 7 8\n",
    "\n",
    "def reading_par(l, n):\n",
    "    res = [0]*len(l)\n",
    "    s = []\n",
    "    n_plus = -1\n",
    "    n_moins = n+1\n",
    "    c = 0\n",
    "    for i in l:\n",
    "        if i == 1:\n",
    "            n_plus += 1\n",
    "            s.append(n_plus)\n",
    "            res[c] = n_plus\n",
    "            c += 1\n",
    "        else:\n",
    "            try:\n",
    "                res[c] = n-s.pop()\n",
    "            except:\n",
    "                res[c] = n_moins\n",
    "                n_moins += 1\n",
    "            c += 1\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dA0MR3f7gjPN"
   },
   "outputs": [],
   "source": [
    "all_par = OrderedDict()\n",
    "for n in range(seq_min_len,seq_max_len+1,2):\n",
    "    a = [0]*n\n",
    "    res = []\n",
    "    all_parent(n=n,a=a,k=-1)\n",
    "    all_par[n] = [reading_par(k,n) for k in res]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JfMc7FvagjPi"
   },
   "outputs": [],
   "source": [
    "all_par_mist = OrderedDict()\n",
    "for n in range(seq_min_len,seq_max_len+1,2):\n",
    "    a = [0]*n\n",
    "    res = []\n",
    "    all_parent_mistake(n=n,a=a,k=-1)\n",
    "    all_par_mist[n] = [reading_par(k,n) for k in res]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7PtA8pQIgjP2",
    "outputId": "fd277dac-b0c6-4269-c416-2a7c05428209"
   },
   "outputs": [],
   "source": [
    "all_par[6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "26eGcBytgjQG",
    "outputId": "0ec9a43b-6c77-44e7-d97a-21695137f692"
   },
   "outputs": [],
   "source": [
    "all_par_mist[6]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "oCo1lHQpgjQT"
   },
   "source": [
    "### number of negative examples by length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mdr-aeg8gjQX"
   },
   "outputs": [],
   "source": [
    "long_mist = {i:len(l) for (i,l) in zip(all_par_mist.keys(),all_par_mist.values())}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7gWfeHiEgjQf",
    "outputId": "b1f5969c-589c-4994-d272-e8a1fd527a2f"
   },
   "outputs": [],
   "source": [
    "long_mist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zq8iVHxDgjQm"
   },
   "source": [
    "### number of positive examples by length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kxzqq0JCgjQo"
   },
   "outputs": [],
   "source": [
    "Catalan_num = {i:len(l) for (i,l) in zip(all_par.keys(),all_par.values())}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WmRoddtQgjQv",
    "outputId": "a14f9876-a8cd-4660-a10c-e7e5a8dce013"
   },
   "outputs": [],
   "source": [
    "Catalan_num"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nzCECk4EgjQ5"
   },
   "source": [
    "Sanity check, see [Catalan numbers](https://en.wikipedia.org/wiki/Catalan_number)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8Thfdkm8gjQ7",
    "outputId": "6ae694e6-f77c-4190-bb81-fcc9cd7bb8f5"
   },
   "outputs": [],
   "source": [
    "[(2*i,Catalan(i)) for i  in range(2,int(seq_max_len/2)+1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qri2vYemgjRF",
    "outputId": "96ac241b-07af-4734-8a02-c368db5b7b91"
   },
   "outputs": [],
   "source": [
    "# nombre de suites correctes de longueur entre 4 et 10, alphabet de taille nb_symbol.\n",
    "nb_symbol = 10\n",
    "np.sum([Catalan(i)*int(nb_symbol/2)**i for i in range(2,int(seq_max_len/2)+1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LVc13MYfgjRR"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import torch\n",
    "\n",
    "class SequenceGenerator():\n",
    "    def __init__(self, nb_symbol = 10, seq_min_len = 4, seq_max_len = 10):\n",
    "        self.nb_symbol = nb_symbol\n",
    "        self.seq_min_len = seq_min_len\n",
    "        self.seq_max_len = seq_max_len\n",
    "        self.population = [i for i in range(int(nb_symbol/2))]\n",
    "                \n",
    "    def generate_pattern(self):\n",
    "        len_r = random.randint(self.seq_min_len/2,self.seq_max_len/2)\n",
    "        pattern = random.choices(self.population,k=len_r)\n",
    "        return pattern + pattern[::-1]\n",
    "    \n",
    "    def generate_pattern_parenthesis(self, len_r = None):\n",
    "        if len_r == None:\n",
    "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
    "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
    "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
    "        res = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
    "        return res\n",
    "    \n",
    "    def generate_parenthesis_false(self):\n",
    "        len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
    "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
    "        ind_r = random.randint(0,long_mist[len_r]-1)\n",
    "        res = [pattern[i] if i <= len_r/2 \n",
    "               else  self.nb_symbol-1-pattern[len_r-i] if i<= len_r \n",
    "               else self.nb_symbol-1-pattern[i-len_r] for i in all_par_mist[len_r][ind_r]]\n",
    "        return res\n",
    "    \n",
    "    def generate_hard_parenthesis(self, len_r = None):\n",
    "        if len_r == None:\n",
    "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
    "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
    "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
    "        res = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
    "        \n",
    "        if len_r == None:\n",
    "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
    "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
    "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
    "        res2 = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
    "        return res + res2\n",
    "    \n",
    "    def generate_hard_nonparenthesis(self, len_r = None):\n",
    "        if len_r == None:\n",
    "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
    "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
    "        ind_r = random.randint(0,long_mist[len_r]-1)\n",
    "        res = [pattern[i] if i <= len_r/2 \n",
    "               else  self.nb_symbol-1-pattern[len_r-i] if i<= len_r \n",
    "               else self.nb_symbol-1-pattern[i-len_r] for i in all_par_mist[len_r][ind_r]]\n",
    "        \n",
    "        if len_r == None:\n",
    "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
    "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
    "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
    "        res2 = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
    "        return  res +[self.nb_symbol-1-pattern[0]]+ res2\n",
    "        \n",
    "    def generate_false(self):\n",
    "        popu = [i for i in range(nb_symbol)]\n",
    "        len = random.randint(self.seq_min_len/2,self.seq_max_len/2)\n",
    "        return random.choices(popu,k=len) + random.choices(popu,k=len)\n",
    "    \n",
    "    def generate_label(self, x):\n",
    "        l = int(len(x)/2)\n",
    "        return 1 if x[:l] == x[:l-1:-1] else 0\n",
    "    \n",
    "    def generate_label_parenthesis(self, x):\n",
    "        s = []\n",
    "        label = 1\n",
    "        lenx = len(x)\n",
    "        for i in x:\n",
    "            if s == [] and i < self.nb_symbol/2:\n",
    "                s.append(i)\n",
    "            elif s == [] and i >= self.nb_symbol/2:\n",
    "                label = 0\n",
    "                break\n",
    "            elif i == self.nb_symbol-1-s[-1]:\n",
    "                s.pop()\n",
    "            else:\n",
    "                s.append(i)\n",
    "        if s != []:\n",
    "            label = 0\n",
    "        return label\n",
    "    \n",
    "    def one_hot(self,seq):\n",
    "        one_hot_seq = []\n",
    "        for s in seq:\n",
    "            one_hot = [0 for _ in range(self.nb_symbol)]\n",
    "            one_hot[s] = 1\n",
    "            one_hot_seq.append(one_hot)\n",
    "        return one_hot_seq\n",
    "    \n",
    "    def generate_input(self, len_r = None, true_parent = False, hard_false = True):\n",
    "        if true_parent:\n",
    "            seq = self.generate_pattern_parenthesis(len_r)\n",
    "        elif bool(random.getrandbits(1)):\n",
    "            seq = self.generate_pattern_parenthesis(len_r)\n",
    "        else:\n",
    "            if hard_false:\n",
    "                seq = self.generate_parenthesis_false()\n",
    "            else:\n",
    "                seq = self.generate_false()\n",
    "        return gpu(torch.from_numpy(np.array(self.one_hot(seq))).type(torch.FloatTensor)), gpu(torch.from_numpy(np.array([self.generate_label_parenthesis(seq)])))\n",
    "\n",
    "    def generate_input_hard(self,true_parent = False):\n",
    "        if true_parent:\n",
    "            seq = self.generate_hard_parenthesis(self.seq_max_len)\n",
    "        elif bool(random.getrandbits(1)):\n",
    "            seq = self.generate_hard_parenthesis(self.seq_max_len)\n",
    "        else:\n",
    "            seq = self.generate_hard_nonparenthesis(self.seq_max_len)\n",
    "            \n",
    "        return gpu(torch.from_numpy(np.array(self.one_hot(seq))).type(torch.FloatTensor)), gpu(torch.from_numpy(np.array([self.generate_label_parenthesis(seq)])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ivBE7UIYgjRc"
   },
   "outputs": [],
   "source": [
    "nb_symbol = 10\n",
    "generator = SequenceGenerator(nb_symbol = nb_symbol, seq_min_len = seq_min_len, seq_max_len = seq_max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Dj5YgzvXgjRq",
    "outputId": "96fb8b2b-0fb9-4092-a48a-bd904ed4541d"
   },
   "outputs": [],
   "source": [
    "generator.generate_pattern_parenthesis()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lbKlAmE_gjRz"
   },
   "outputs": [],
   "source": [
    "x = generator.generate_parenthesis_false()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2kdSGWGCgjR7",
    "outputId": "cb3f1cb0-fd6d-4d16-fb51-fd628dc01ebb"
   },
   "outputs": [],
   "source": [
    "generator.generate_label_parenthesis(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4hL2UnY6gjSC",
    "outputId": "dc6f9546-b4b6-4fae-ab27-942ea91d6163"
   },
   "outputs": [],
   "source": [
    "generator.generate_input()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hkOTjH01gjSN"
   },
   "source": [
    "## 2. First RNN: [Elman network](https://mlelarge.github.io/dataflowr-slides/PlutonAI/lesson7.html#16)\n",
    "\n",
    "Initial hidden state: $h_0 =0$\n",
    "\n",
    "Update:\n",
    "\n",
    "$$\n",
    "h_t = \\mathrm{ReLU}(W_{xh} x_t + W_{hh} h_{t-1} + b_h)\n",
    "$$\n",
    "\n",
    "Final prediction:\n",
    "\n",
    "$$\n",
    "y_T = W_{hy} h_T + b_y.\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bLIQlP1QgjSP"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class RecNet(nn.Module):\n",
    "    def __init__(self, dim_input=10, dim_recurrent=50, dim_output=2):\n",
    "        super(RecNet, self).__init__()\n",
    "        self.fc_x2h = nn.Linear(dim_input, dim_recurrent)\n",
    "        self.fc_h2h = nn.Linear(dim_recurrent, dim_recurrent, bias = False)\n",
    "        self.fc_h2y = nn.Linear(dim_recurrent, dim_output)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        h = x.new_zeros(1, self.fc_h2y.weight.size(1))\n",
    "        for t in range(x.size(0)):\n",
    "            h = torch.relu(self.fc_x2h(x[t,:]) + self.fc_h2h(h))    \n",
    "        return self.fc_h2y(h)\n",
    "    \n",
    "RNN = gpu(RecNet(dim_input = nb_symbol))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "AcbhSU5ggjSW",
    "outputId": "49f5cf6b-b2a1-4319-bd29-5eb79a573c69"
   },
   "outputs": [],
   "source": [
    "cross_entropy = nn.CrossEntropyLoss()\n",
    "learning_rate = 1e-3\n",
    "optimizer = torch.optim.Adam(RNN.parameters(),lr=learning_rate)\n",
    "nb_train = 40000\n",
    "loss_t = []\n",
    "corrects =[]\n",
    "labels = []\n",
    "start = time.time()\n",
    "for k in range(nb_train):\n",
    "    x,l = generator.generate_input(hard_false = False)\n",
    "    y = RNN(x)\n",
    "    loss = cross_entropy(y,l)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    corrects.append(preds.item() == l.data.item())\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    loss_t.append(loss)\n",
    "    labels.append(l.data)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NilV4syzgjSi",
    "outputId": "d4f030d9-333d-44a4-9005-af6da9feaa8a"
   },
   "outputs": [],
   "source": [
    "plt.plot(running_mean(loss_t,int(nb_train/100)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LINKeGNjgjSs",
    "outputId": "f459edff-08b3-46b9-9bac-d194c480bf4c"
   },
   "outputs": [],
   "source": [
    "plt.plot(running_mean(corrects,int(nb_train/100)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "O2A1jCLJgjS5"
   },
   "outputs": [],
   "source": [
    "nb_test = 1000\n",
    "corrects_test =[]\n",
    "labels_test = []\n",
    "for k in range(nb_test):\n",
    "    x,l = generator.generate_input(len_r=seq_max_len,true_parent=True)\n",
    "    y = RNN(x)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    corrects_test.append(preds.item() == l.data.item())\n",
    "    labels_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TdgtK_tGZvJR"
   },
   "source": [
    "Accuracy on valid parenthesis strings only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "z41nP-UhgjS-",
    "outputId": "6526283e-2cba-43a8-a70c-69de63fa5004"
   },
   "outputs": [],
   "source": [
    "np.sum(corrects_test)/nb_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lzOpkaLigjTD"
   },
   "outputs": [],
   "source": [
    "nb_test = 1000\n",
    "corrects_test =[]\n",
    "labels_test = []\n",
    "for k in range(nb_test):\n",
    "    x,l = generator.generate_input(len_r=seq_max_len, hard_false = True)\n",
    "    y = RNN(x)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    corrects_test.append(preds.item() == l.data.item())\n",
    "    labels_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FjT6sbUSZvJV"
   },
   "source": [
    "Accuracy on a test set (similar to the training set):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Geelj-5XgjTM",
    "outputId": "d39f0642-f8c5-4454-cd1a-c67268520c81"
   },
   "outputs": [],
   "source": [
    "np.sum(corrects_test)/nb_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "crSKTEJHgjTU"
   },
   "outputs": [],
   "source": [
    "nb_test = 1000\n",
    "correctsh_test =[]\n",
    "labelsh_test = []\n",
    "for k in range(nb_test):\n",
    "    x,l = generator.generate_input_hard()\n",
    "    y = RNN(x)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    correctsh_test.append(preds.item() == l.data.item())\n",
    "    labelsh_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rioM-_UkZvJa"
   },
   "source": [
    "Accuracy on a test set of hard instances, i.e. instances longer than those seen during the training :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "br6HRXOhgjTc",
    "outputId": "6af91642-e7e0-4fba-d71d-3a8968889385"
   },
   "outputs": [],
   "source": [
    "np.sum(correctsh_test)/nb_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "I5q30nkqgjTn"
   },
   "outputs": [],
   "source": [
    "nb_test = 1000\n",
    "correctsh_test =[]\n",
    "labelsh_test = []\n",
    "for k in range(nb_test):\n",
    "    x,l = generator.generate_input_hard(true_parent=True)\n",
    "    y = RNN(x)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    correctsh_test.append(preds.item() == l.data.item())\n",
    "    labelsh_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bOPC3MomZvJg"
   },
   "source": [
    "It looks like our network is always prediciting a valid label for long sequences:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "L6ZM_4c2gjTw",
    "outputId": "354163be-56a8-424c-a4c8-91dc01063507"
   },
   "outputs": [],
   "source": [
    "np.sum(correctsh_test)/nb_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wvaIEg6igjT3"
   },
   "source": [
    "## 3. [RNN with Gating](https://mlelarge.github.io/dataflowr-slides/PlutonAI/lesson7.html#20)\n",
    "\n",
    "$$\n",
    "\\overline{h}_t = \\mathrm{ReLU}(W_{xh} x_t + W_{hh} h_{t-1} + b_h)\n",
    "$$\n",
    "Forget gate:\n",
    "$$\n",
    "z_t = \\mathrm{sigm}(W_{xz} x_t + W_{hz}h_{t-1}+b_z)\n",
    "$$\n",
    "Hidden state:\n",
    "$$\n",
    "h_t = z_t\\odot h_{t-1} +(1-z_t) \\odot \\overline{h}_t\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kOiH5WsLgjT5"
   },
   "outputs": [],
   "source": [
    "class RecNetGating(nn.Module):\n",
    "    def __init__(self, dim_input=10, dim_recurrent=50, dim_output=2):\n",
    "        super(RecNetGating, self).__init__()\n",
    "        self.fc_x2h = nn.Linear(dim_input, dim_recurrent)\n",
    "        self.fc_h2h = nn.Linear(dim_recurrent, dim_recurrent, bias = False)\n",
    "        self.fc_x2z = nn.Linear(dim_input, dim_recurrent)\n",
    "        self.fc_h2z = nn.Linear(dim_recurrent,dim_recurrent, bias = False)\n",
    "        self.fc_h2y = nn.Linear(dim_recurrent, dim_output)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        h = x.new_zeros(1, self.fc_h2y.weight.size(1))\n",
    "        for t in range(x.size(0)):\n",
    "            z = torch.sigmoid(self.fc_x2z(x[t,:])+self.fc_h2z(h))\n",
    "            hb = torch.relu(self.fc_x2h(x[t,:]) + self.fc_h2h(h))\n",
    "            h = z * h + (1-z) * hb   \n",
    "        return self.fc_h2y(h)    \n",
    "    \n",
    "RNNG = gpu(RecNetGating(dim_input = nb_symbol))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NZuKhEg-gjUB",
    "outputId": "c88b2de9-4e54-4e0c-faa2-fc8cd73eff6f"
   },
   "outputs": [],
   "source": [
    "optimizerG = torch.optim.Adam(RNNG.parameters(),lr=1e-3)\n",
    "loss_tG = []\n",
    "correctsG =[]\n",
    "labelsG = []\n",
    "start = time.time()\n",
    "for k in range(nb_train):\n",
    "    x,l = generator.generate_input(hard_false = False)\n",
    "    y = RNNG(x)\n",
    "    loss = cross_entropy(y,l)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    correctsG.append(preds.item() == l.data.item())\n",
    "    optimizerG.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizerG.step()\n",
    "    loss_tG.append(loss)\n",
    "    labelsG.append(l.item())\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6BbKxWkTgjUI",
    "outputId": "672c25a9-2791-47df-fe73-e94bb500b388"
   },
   "outputs": [],
   "source": [
    "plt.plot(running_mean(loss_tG,int(nb_train/50)))\n",
    "plt.plot(running_mean(loss_t,int(nb_train/50)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Vjp8ybXqgjUO",
    "outputId": "ddd7febf-bebc-4c6a-c999-a37052cfaf5f"
   },
   "outputs": [],
   "source": [
    "plt.plot(running_mean(correctsG,int(nb_train/50)))\n",
    "plt.plot(running_mean(corrects,int(nb_train/50)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "opLeQwWogjUU"
   },
   "outputs": [],
   "source": [
    "nb_test = 1000\n",
    "correctsG_test =[]\n",
    "labelsG_test = []\n",
    "for k in range(nb_test):\n",
    "    x,l = generator.generate_input(len_r=seq_max_len,true_parent=True)\n",
    "    y = RNNG(x)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    correctsG_test.append(preds.item() == l.data.item())\n",
    "    labelsG_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wOL4IOo8ZvJx"
   },
   "source": [
    "Accuracy on valid parenthesis strings only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FKF7r8dSgjUc",
    "outputId": "3fe4e03e-4556-4b32-f5f7-68e4b60bae65"
   },
   "outputs": [],
   "source": [
    "np.sum(correctsG_test)/nb_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IwQg9tXRgjUq"
   },
   "outputs": [],
   "source": [
    "nb_test = 1000\n",
    "correctsG_test =[]\n",
    "labelsG_test = []\n",
    "for k in range(nb_test):\n",
    "    x,l = generator.generate_input(len_r=seq_max_len, hard_false = True)\n",
    "    y = RNNG(x)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    correctsG_test.append(preds.item() == l.data.item())\n",
    "    labelsG_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OydDhUHiZvJ2"
   },
   "source": [
    "Accuracy on a test set (similar to the training set):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QdWoh4UTgjUv",
    "outputId": "aab623ca-43b6-457b-c00a-82795696dedd"
   },
   "outputs": [],
   "source": [
    "np.sum(correctsG_test)/nb_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OkkC1_jMgjU2"
   },
   "outputs": [],
   "source": [
    "nb_test = 1000\n",
    "correctshG_test =[]\n",
    "labelshG_test = []\n",
    "for k in range(nb_test):\n",
    "    x,l = generator.generate_input_hard()\n",
    "    y = RNNG(x)\n",
    "    _,preds = torch.max(y.data,1)\n",
    "    correctshG_test.append(preds.item() == l.data.item())\n",
    "    labelshG_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EDiAHxCIZvJ6"
   },
   "source": [
    "Accuracy on a test set of hard instances, i.e. instances longer than those seen during the training :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kUMHjsXrgjVB",
    "outputId": "2cc93a81-1d95-48df-ab35-db938a4ecc9d"
   },
   "outputs": [],
   "source": [
    "np.sum(correctshG_test)/nb_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9GO-LItVgjVN"
   },
   "source": [
    "## 4. [LSTM](https://mlelarge.github.io/dataflowr-slides/PlutonAI/lesson7.html#27)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4gTATkvrgjVR"
   },
   "outputs": [],
   "source": [
    "class LSTMNet(nn.Module):\n",
    "    \"\"\"Stacked LSTM whose last time step feeds a linear read-out layer.\"\"\"\n",
    "\n",
    "    def __init__(self, dim_input=10, dim_recurrent=50, num_layers=4, dim_output=2):\n",
    "        super().__init__()\n",
    "        self.lstm = nn.LSTM(input_size=dim_input,\n",
    "                            hidden_size=dim_recurrent,\n",
    "                            num_layers=num_layers)\n",
    "        self.fc_o2y = nn.Linear(dim_recurrent, dim_output)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Score a single sequence from the LSTM output at its last time step.\n",
    "\n",
    "        A singleton axis is inserted at position 1 before the LSTM call\n",
    "        (nn.LSTM is built with the default batch_first=False, so axis 0 is\n",
    "        time and axis 1 the batch) and removed again afterwards. The returned\n",
    "        tensor keeps a leading dimension of size 1.\n",
    "        NOTE(review): presumably x is (seq_len, dim_input) - confirm with the generator.\n",
    "        \"\"\"\n",
    "        hidden_seq, _ = self.lstm(x.unsqueeze(1))\n",
    "        # Keep only the final time step, as a slice so the leading axis survives.\n",
    "        last_step = hidden_seq.squeeze(1)[-1:]\n",
    "        return self.fc_o2y(F.relu(last_step))\n",
    "\n",
    "\n",
    "lstm = gpu(LSTMNet(dim_input=nb_symbol))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WkjccxdsgjVX"
   },
   "outputs": [],
   "source": [
    "x, l = generator.generate_input()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sgdPdLupgjVf",
    "outputId": "d4aa32eb-8c43-400e-e3a2-ad2845bc3610"
   },
   "outputs": [],
   "source": [
    "lstm(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Rp2KI_eSgjVk",
    "outputId": "7d07ebea-2c65-4a83-bda4-5c821b10bd28"
   },
   "outputs": [],
   "source": [
    "# Train the LSTM for nb_train steps, one generated sequence per step.\n",
    "optimizerL = torch.optim.Adam(lstm.parameters(), lr=1e-3)\n",
    "# Per-step history, stored as plain Python values so that it can be handed\n",
    "# to NumPy (running_mean) later without holding on to any tensor.\n",
    "loss_tL = []\n",
    "correctsL = []\n",
    "labelsL = []\n",
    "start = time.time()\n",
    "for k in range(nb_train):\n",
    "    x, l = generator.generate_input(hard_false=False)\n",
    "    y = lstm(x)\n",
    "    loss = cross_entropy(y, l)\n",
    "    _, preds = torch.max(y.detach(), 1)\n",
    "    correctsL.append(preds.item() == l.item())\n",
    "    optimizerL.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizerL.step()\n",
    "    # .item() copies the scalar out of the tensor; appending `loss` itself\n",
    "    # would keep nb_train tensors (and their grad_fn) alive and would break\n",
    "    # the NumPy conversion when the tensors live on the GPU.\n",
    "    loss_tL.append(loss.item())\n",
    "    labelsL.append(l.item())\n",
    "# Wall-clock duration of the training loop, in seconds.\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LwLenhrCgjVo",
    "outputId": "68edd7eb-a170-4a69-817f-2c11566a7bb4"
   },
   "outputs": [],
   "source": [
    "# Smoothed training losses of the three models on a single figure.\n",
    "# NOTE(review): loss_tG and loss_t are filled by earlier cells; the labels\n",
    "# assume they belong to RNNG and to the first RNN respectively - adjust if not.\n",
    "plt.plot(running_mean(loss_tL, int(nb_train/50)), label='LSTM')\n",
    "plt.plot(running_mean(loss_tG, int(nb_train/50)), label='RNNG')\n",
    "plt.plot(running_mean(loss_t, int(nb_train/50)), label='RNN')\n",
    "plt.xlabel('training step')\n",
    "plt.ylabel('loss (running mean)')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ng_wRELigjVt",
    "outputId": "7e6c73fd-5b3d-466a-8209-14445a5c1253"
   },
   "outputs": [],
   "source": [
    "# Smoothed training accuracies of the three models on a single figure.\n",
    "# NOTE(review): correctsG and corrects are filled by earlier cells; the labels\n",
    "# assume they belong to RNNG and to the first RNN respectively - adjust if not.\n",
    "plt.plot(running_mean(correctsL, int(nb_train/50)), label='LSTM')\n",
    "plt.plot(running_mean(correctsG, int(nb_train/50)), label='RNNG')\n",
    "plt.plot(running_mean(corrects, int(nb_train/50)), label='RNN')\n",
    "plt.xlabel('training step')\n",
    "plt.ylabel('accuracy (running mean)')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "oFfEMA89gjVw"
   },
   "outputs": [],
   "source": [
    "# Evaluate the LSTM on nb_test sequences drawn with true_parent=True.\n",
    "nb_test = 1000\n",
    "correctsL_test = []\n",
    "labelsL_test = []\n",
    "# Inference only: no_grad() stops autograd from recording the forward passes.\n",
    "with torch.no_grad():\n",
    "    for k in range(nb_test):\n",
    "        x, l = generator.generate_input(len_r=seq_max_len, true_parent=True)\n",
    "        y = lstm(x)\n",
    "        _, preds = torch.max(y, 1)\n",
    "        correctsL_test.append(preds.item() == l.item())\n",
    "        labelsL_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sSH-Hg2fZvKP"
   },
   "source": [
    "Accuracy on valid parenthesis strings only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mn4MCPRDgjV2",
    "outputId": "d368be32-95b1-4df4-b2f3-7d7c9d1e546d"
   },
   "outputs": [],
   "source": [
    "# Fraction of True entries in correctsL_test. Averaging over the list itself\n",
    "# (rather than dividing by the global nb_test, which other cells reassign)\n",
    "# keeps the figure right even when cells are re-run in a different order.\n",
    "np.mean(correctsL_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CIvIS-gEgjV-"
   },
   "outputs": [],
   "source": [
    "# Evaluate the LSTM on nb_test sequences drawn with true_parent=False and\n",
    "# hard_false=True. This overwrites correctsL_test / labelsL_test.\n",
    "nb_test = 1000\n",
    "correctsL_test = []\n",
    "labelsL_test = []\n",
    "# Inference only: no_grad() stops autograd from recording the forward passes.\n",
    "with torch.no_grad():\n",
    "    for k in range(nb_test):\n",
    "        x, l = generator.generate_input(len_r=seq_max_len, true_parent=False, hard_false=True)\n",
    "        y = lstm(x)\n",
    "        _, preds = torch.max(y, 1)\n",
    "        correctsL_test.append(preds.item() == l.item())\n",
    "        labelsL_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "h_GZSFHxZvKR"
   },
   "source": [
    "Accuracy on a test set (similar to the training set):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "b06Gp7lrgjWC",
    "outputId": "b5ca9136-7cd8-41bb-e3a4-bc09d48c7e50"
   },
   "outputs": [],
   "source": [
    "# Fraction of True entries in correctsL_test. Averaging over the list itself\n",
    "# (rather than dividing by the global nb_test, which other cells reassign)\n",
    "# keeps the figure right even when cells are re-run in a different order.\n",
    "np.mean(correctsL_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rniUHL4mgjWE"
   },
   "outputs": [],
   "source": [
    "# Evaluate the LSTM on nb_test sequences from generator.generate_input_hard().\n",
    "nb_test = 1000\n",
    "correctshL_test = []\n",
    "labelshL_test = []\n",
    "# Inference only: no_grad() stops autograd from recording the forward passes.\n",
    "with torch.no_grad():\n",
    "    for k in range(nb_test):\n",
    "        x, l = generator.generate_input_hard()\n",
    "        y = lstm(x)\n",
    "        _, preds = torch.max(y, 1)\n",
    "        correctshL_test.append(preds.item() == l.item())\n",
    "        labelshL_test.append(l.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tskqOH5vZvKV"
   },
   "source": [
    "Accuracy on a test set of hard instances, i.e. instances longer than those seen during training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Zh43kUcGgjWI",
    "outputId": "db1cff01-39c5-4553-a7d4-3601d85ceb9e"
   },
   "outputs": [],
   "source": [
    "# Fraction of True entries in correctshL_test. Averaging over the list itself\n",
    "# (rather than dividing by the global nb_test, which other cells reassign)\n",
    "# keeps the figure right even when cells are re-run in a different order.\n",
    "np.mean(correctshL_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mWUSVdiggjWO"
   },
   "source": [
    "## 5. GRU\n",
    "\n",
    "Implement your RNN with a [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JSH65GakgjWV"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OvHhG7KfgjWe"
   },
   "source": [
    "## 6. Explore!\n",
    "\n",
    "What are good negative examples?\n",
    "\n",
    "How to be sure that your network 'generalizes'?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lCVb66-xgjWj"
   },
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": false,
   "name": "11_RNN.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
